import collections
import os
import pickle

import numpy as np
import pandas as pd
import torch
from sklearn.model_selection import StratifiedKFold

from BERT import BertClassifier, train, Dataset

# ---------------------------------------------------------------------------
# Load the comment table and the list of feature columns, then clean the table.
# ---------------------------------------------------------------------------
df_data = pd.read_csv('./feature.GB18030.v10.csv', encoding='GB18030')
df_features = pd.read_csv('./feature_names.csv')
feature_names = df_features['feature'].tolist()

# Impute every missing cell with the most frequent value of its column
# (first row of DataFrame.mode()).
# NOTE(review): this also fills missing 'plain_text' cells, so the dropna below
# only removes rows whose column had no mode at all -- confirm this ordering is
# intended.
modes_dict = df_data.mode().iloc[0].to_dict()
df_data = df_data.fillna(modes_dict)

# Discard rows still missing any feature column or the text itself.
df_data = df_data.dropna(subset=[*feature_names, 'plain_text'])

# Keep only rows whose text has at least `min_txt_len` characters.
min_txt_len = 1
filter2 = df_data['plain_text'].map(len) >= min_txt_len
df_data = df_data[filter2]

# Threshold constants; nothing in the visible script reads them.
# NOTE(review): kept in case other code depends on them -- remove if truly dead.
split_x_high = 9.0
split_y_high = 6.0

split_x_low = 2.0
split_y_low = 2.0

# Binary target: 0 when a comment has neither likes nor child comments,
# 1 otherwise. Computed column-wise instead of row-by-row with iterrows().
# The mask is built as "not (both <= 0)" so that a NaN in either column still
# yields label 1, exactly as the original if/else did.
no_engagement = (df_data['comment_like_num'] <= 0) & (df_data['child_comment_num'] <= 0)
binary_labels = (~no_engagement).astype(int).to_numpy()

# Show the class balance before training.
print(collections.Counter(binary_labels))

# ---------------------------------------------------------------------------
# 10-fold stratified cross-validation of the BERT-based binary classifier.
# One pickle with the values returned by train() is written per fold.
# ---------------------------------------------------------------------------

# Training hyper-parameters.
batch_size = 128
LR = 1e-6
epochs = 3
output_dim = 2

kf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)

# Select the GPU only when CUDA is actually present. Calling
# torch.cuda.set_device() unconditionally raises on a CPU-only machine, which
# made the CPU fallback below unreachable.
gpu_index = 2
use_cuda = torch.cuda.is_available()
if use_cuda:
    torch.cuda.set_device(gpu_index)
device = torch.device("cuda" if use_cuda else "cpu")

pretrained_model = 'hfl/chinese-roberta-wwm-ext'
# '/' is not usable inside a file name, so flatten the model id once up front.
model_name = pretrained_model.replace('/', '-')

# Create the output directory before training so that a missing folder cannot
# discard a finished fold at pickle time.
os.makedirs('./results', exist_ok=True)

# NOTE(review): rows are filtered on 'plain_text' but the model is fed
# 'raw_text' -- confirm this is intended.
raw_texts = df_data['raw_text']

fold_idx = 0
for fold_idx, (train_ids, test_ids) in enumerate(
        kf.split(df_data['plain_text'], y=binary_labels), start=1):
    train_labels = binary_labels[train_ids]
    test_labels = binary_labels[test_ids]

    # kf.split() yields *positional* indices. After dropna / boolean filtering
    # the DataFrame index may have gaps, so plain [] (label-based) lookup would
    # raise KeyError or pick the wrong rows; .iloc selects by position.
    train_data = raw_texts.iloc[train_ids]
    test_data = raw_texts.iloc[test_ids]

    # Inverse-frequency class weights: n_samples / (n_classes * class_count).
    label_counter = collections.Counter(train_labels)
    label_dist = [label_counter[e] for e in range(output_dim)]
    if 0 in label_dist:
        raise ValueError(
            f'fold {fold_idx}: training split has no samples for at least one '
            f'of the {output_dim} classes (counts={label_dist}); '
            'cannot compute class weights'
        )
    label_weight = [sum(label_dist) / (len(label_dist) * count) for count in label_dist]

    # A fresh model per fold so that no weights leak between folds.
    model = BertClassifier(output_dim=output_dim, pretrained_model=pretrained_model)
    train_dataset = Dataset(texts=train_data, labels=train_labels, pretrained_model=pretrained_model)
    test_dataset = Dataset(texts=test_data, labels=test_labels, pretrained_model=pretrained_model)
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    # Evaluation data does not need shuffling; a fixed order keeps the saved
    # outputs reproducible.
    val_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    model, total_y_true, total_y_preds = train(
        model, train_dataloader, val_dataloader, LR, epochs, device,
        use_cuda=use_cuda, label_weight=label_weight,
    )

    # Persist this fold's outputs; files are numbered from 1.
    with open(f'./results/llm_binary_{model_name}_{fold_idx}.pkl', 'wb') as f:
        pickle.dump({'y_true': total_y_true, 'y_preds': total_y_preds}, f)
